{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Motivation\n",
    "\n",
    "In this notebook, we perform a clustering task on a simplified version of the [TrackML](https://www.kaggle.com/c/trackml-particle-identification) dataset, in which each sample is reduced from $\\approx 10^5$ points to only $\\approx 10^2$ points. TrackML's goal is to reconstruct the trajectories of particles created when proton bunches, accelerated to near the speed of light, collide. The particles pass through many layers of a detector (e.g., ATLAS), and the recorded impacts form a point cloud. Clustering this point cloud so that each cluster corresponds to one particle is sufficient for physicists to extract the particles' trajectories and thus determine their kinematic properties.\n",
    "\n",
    "This notebook presents a simple algorithm that uses a GNN to learn a good embedding for a classical clustering algorithm such as DBSCAN. While end-to-end and hierarchical clustering are highly desirable, they present additional challenges that go beyond this tutorial's purpose of featuring DGL. Interested readers may check out the algorithm introduced in [End-to-End Hierarchical Clustering with Graph Neural Networks](https://cs.nyu.edu/media/publications/choma_nicholas.pdf)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pipeline\n",
    "\n",
    "Graph Neural Networks are an attractive candidate because they can aggregate information over variable-sized neighborhoods of points; they are not limited to modeling single-point or pairwise interactions. Because we begin with only a point cloud, our pipeline is as follows:\n",
    "\n",
    "1. **Metric learning:** Embed the points into a Euclidean space using an MLP on each point individually. Points belonging to the same particle should be close; points belonging to different particles should be far apart. *(provided as input data)*\n",
    "1. **Graph Construction:** Find each point's nearest neighbors in the embedded space, and connect them with a graph. Ensure neighborhoods are large enough to connect most of the points belonging to the same particle. *(provided as input data)*\n",
    "1. **GNN:** Embed the points into another Euclidean space, this time incorporating neighborhood information. Because each point is aware of its neighbors, the GNN produces a superior embedding to the metric learning stage.\n",
    "1. **Clustering:** Using the GNN's embedding, cluster the points with DBSCAN. Adjust DBSCAN's hyperparameters until the TrackML score is maximized."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Loading Data\n",
    "\n",
    "The dataset consists of detection *events*, in which a variable number of particles come into contact with the detector's cells. The locations of these impacts are called *hits*, and the number of hits per particle likewise varies."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import numpy as np\n",
    "\n",
    "with open('tracks.pickle', 'rb') as f:\n",
    "    samples = pickle.load(f)\n",
    "\n",
    "print(\"Loaded {} samples.\".format(len(samples)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Features\n",
    "\n",
    "Each sample contains a set of hits, and each hit contains the following information:\n",
    "\n",
    "* *x,y,z* coordinates\n",
    "* Cell count and impact magnitude\n",
    "* A learned hit embedding, output by the metric learning stage and used for graph creation\n",
    "* Ground truth cluster ID, denoting the particle which created the hit\n",
    "\n",
    "Each sample also contains graphs output by the previous stage, which aims to connect hits created by the same particle. The two graphs included are\n",
    "\n",
    "* A predicted graph, the raw output from the graph building stage\n",
    "* An augmented graph, which contains the predicted graph, plus any connections missed between hits created by the same particle (positive pairs). One could also add some negative pairs to the graph. This is used in the GNN's loss function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(samples[0].keys())\n",
    "print(samples[0]['hits'].keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualizations\n",
    "\n",
    "Choosing a sample to explore, one can see how the embedding differs from the raw features for graph creation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "def plot_clusters(x,y,pid):\n",
    "    for g in np.unique(pid):\n",
    "        i = np.where(pid == g)\n",
    "        plt.scatter(x[i],y[i], label=g)\n",
    "    plt.show()\n",
    "    plt.clf()\n",
    "\n",
    "hits = samples[1]['hits']\n",
    "xyz = hits['xyz']\n",
    "emb = hits['emb']\n",
    "pid = hits['particle_id']\n",
    "\n",
    "# Hit coordinates\n",
    "plot_clusters(xyz[:,0], xyz[:,1], pid)\n",
    "\n",
    "# Emb coordinates\n",
    "plot_clusters(emb[:,0], emb[:,1], pid)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Embedding Model\n",
    "\n",
    "To create this embedding, an MLP is trained to transform each hit into its embedded representation. A hinge loss over pairs of embedded hits supervises the MLP's learning process. To keep this tutorial succinct, the embedding produced by the already-trained MLP is included in the dataset.\n",
    "\n",
    "Clearly, the embedding will lead to superior clustering as compared with the raw *x,y,z* positions. \n",
    "However, this embedding incorporates information from only each hit individually. \n",
    "With a GNN, one can create embeddings which incorporate information from the hit's neighborhood. \n",
    "As we will see, this allows for superior embeddings and thus improved performance in clustering."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Graph\n",
    "\n",
    "Once each point is embedded, the graph is built in the following steps:\n",
    "\n",
    "1. Build a k-d tree with the embedded points.\n",
    "1. Query the neighborhood of every point, specifying either $k$, the number of neighbors, or $\\epsilon$, the radius of the neighborhood around each point. In practice, $\\epsilon$-ball neighborhoods produce favorable results.\n",
    "1. Connect each point to its neighbors with an undirected graph.\n",
    "\n",
    "To keep this tutorial succinct, these steps have already been performed and the graph for each sample is included in the dataset.\n",
    "\n",
    "<img src=\"img/graph.png\">"
   ]
  },
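  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The three steps above can be sketched on a toy point set. This is a minimal illustration, not the code that produced the dataset's graphs: the use of `scipy.spatial.cKDTree`, the helper name `build_eps_graph`, the toy points, and the value of `eps` are all assumptions made for the example.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from scipy.spatial import cKDTree\n",
    "\n",
    "def build_eps_graph(points, eps):\n",
    "    # Hypothetical helper: undirected edges (i, j), i < j, for all pairs within distance eps\n",
    "    tree = cKDTree(points)            # step 1: k-d tree over the embedded points\n",
    "    pairs = tree.query_pairs(r=eps)   # step 2: epsilon-ball query for every point\n",
    "    return sorted(pairs)              # step 3: one entry per undirected edge\n",
    "\n",
    "# Toy embedding (made-up values): two tight groups of points, far from each other\n",
    "toy_emb = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1],\n",
    "                    [5.0, 5.0], [5.1, 5.0]])\n",
    "edges = build_eps_graph(toy_emb, eps=0.2)\n",
    "print(edges)\n",
    "```\n",
    "\n",
    "Each undirected edge appears once as `(i, j)` with `i < j`. For a library that stores directed edges, such as DGL, both `(i, j)` and `(j, i)` would need to be added."
   ]
  },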
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model\n",
    "\n",
    "The GNN model chosen is a simple message-passing architecture. One layer concatenates each node's features with an aggregation of the node's neighborhood, before applying a transformation via a fully-connected neural network layer.\n",
    "\n",
    "The output of the model is a set of node embeddings, where this new embedding has the same goal as in the graph building stage: according to some distance metric, node pairs whose hits belong to the same particle should be close, and otherwise they should be far."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dgl\n",
    "import torch\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Graph Weighting\n",
    "A multi-layer perceptron kernel determines the edge weights of the graph at each layer.\n",
    "Here, each edge is represented by the features of its adjoining nodes.\n",
    "These features are passed through the MLP to produce an edge weight between 0 and 1.\n",
    "\n",
    "Consider the feature vectors of any two nodes, $x_i$ and $x_j$, which have an edge between them.\n",
    "Then their edge weight $w_{ij}$ is given as\n",
    "\n",
    "$$\n",
    "\\begin{align}\n",
    "w_{ij} &= \\text{sigmoid}(f([x_i, x_j])),\n",
    "\\end{align}\n",
    "$$\n",
    "\n",
    "where $f$ is an MLP as defined in the graph weighting kernel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP_Kernel_DGL(nn.Module):\n",
    "    def __init__(self, nb_input, nb_hidden_gnn, nb_output=1, nb_layer=1):\n",
    "        super(MLP_Kernel_DGL, self).__init__()\n",
    "        layers = [nn.Linear(nb_input*2, nb_hidden_gnn)]\n",
    "        for _ in range(nb_layer-1):\n",
    "            layers.append(nn.Linear(nb_hidden_gnn, nb_hidden_gnn))\n",
    "        layers.append(nn.Linear(nb_hidden_gnn, nb_output))\n",
    "        self.layers = nn.ModuleList(layers)\n",
    "        self.act1 = nn.ReLU()\n",
    "        self.act2 = nn.Sigmoid()\n",
    "\n",
    "    def forward(self, g):\n",
    "        g.apply_edges(self.mlp)\n",
    "        return g\n",
    "\n",
    "    def mlp(self, edges):\n",
    "        # Gather features from all relevant node pairs\n",
    "        src = edges.src['feat']\n",
    "        dst = edges.dst['feat']\n",
    "        e_feats = torch.cat((src,dst),dim=1)\n",
    "        \n",
    "        # Apply MLP layers to node pairs\n",
    "        for l in self.layers[:-1]:\n",
    "            e_feats = self.act1(l(e_feats))\n",
    "        \n",
    "        # Apply final output with sigmoid\n",
    "        e_feats = self.layers[-1](e_feats)\n",
    "        e_feats = self.act2(e_feats)\n",
    "        return {'e' : e_feats}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GNN Layer\n",
    "Each GNN layer optionally batch-normalizes the incoming node features, recomputes the edge weights with that layer's edge weight kernel, and then applies a graph convolution.\n",
    "\n",
    "Consider node $i$ at a given layer, and the neighborhood of $i$ given as $N_i$.\n",
    "Then the feature vector $x_i$ is updated to $x'_i$ in the following way.\n",
    "First, incoming messages to $i$ are weighted and aggregated:\n",
    "\n",
    "$$\n",
    "\\begin{align}\n",
    "z_i = \\sum_{j \\in N_i} w_{ij} * x_{j}.\n",
    "\\end{align}\n",
    "$$\n",
    "\n",
    "Next, $x_i$ and $z_i$ are concatenated and transformed:\n",
    "\n",
    "$$\n",
    "\\begin{align}\n",
    "x'_i = \\text{ReLU}(\\phi([x_i, z_i])),\n",
    "\\end{align}\n",
    "$$\n",
    "\n",
    "where $\\phi$ is a learned affine transformation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Complete GNN layer, including normalization, graph weighting, and convolution\n",
    "class GNN_Layer(nn.Module):\n",
    "    def __init__(self, input_dim, nb_hidden_gnn, nb_hidden_kernel, apply_norm=True):\n",
    "        super(GNN_Layer, self).__init__()\n",
    "\n",
    "        self.edge_weighting = MLP_Kernel_DGL(input_dim, nb_hidden_kernel)\n",
    "        self.bn = nn.BatchNorm1d(input_dim,momentum=0.10) if apply_norm else None\n",
    "        self.fc = nn.Linear(2*input_dim, nb_hidden_gnn)\n",
    "        self.act = nn.ReLU()\n",
    "\n",
    "    def forward(self, g, features):\n",
    "        # maybe apply normalization\n",
    "        if self.bn is not None:\n",
    "            features = self.bn(features)\n",
    "        g.ndata['feat'] = features\n",
    "\n",
    "        # set edge weights for this layer\n",
    "        g = self.edge_weighting(g)\n",
    "        \n",
    "        # send weighted messages and apply graph convolution to nodes\n",
    "        g.update_all(message_func=dgl.function.u_mul_e('feat', 'e', 'msg'),\n",
    "                     reduce_func=dgl.function.sum('msg', 'agg_msg'))\n",
    "        \n",
    "        # concat and apply an affine transformation\n",
    "        node_feats = torch.cat((features, g.ndata['agg_msg']), dim=1)\n",
    "        emb = self.fc(node_feats)\n",
    "        emb = self.act(emb)\n",
    "        \n",
    "        return emb"
   ]
  },
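  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the update rule concrete, the sketch below evaluates the weighted aggregation and the concatenate-and-transform step by hand on a three-node toy graph with NumPy. It only mirrors the two equations above; the node features, the edge weights `w`, and the affine map (`phi_w`, `phi_b`) are made-up values standing in for the learned kernel and layer, not outputs of the trained model.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Toy node features x_i: 3 nodes, 2 features each (illustrative values)\n",
    "x = np.array([[1.0, 0.0],\n",
    "              [0.0, 2.0],\n",
    "              [3.0, 1.0]])\n",
    "\n",
    "# Directed edges src -> dst with hypothetical weights w_ij in (0, 1]\n",
    "src = np.array([1, 2, 0])\n",
    "dst = np.array([0, 0, 1])\n",
    "w = np.array([0.5, 0.1, 1.0])\n",
    "\n",
    "# z_i: sum of w_ij * x_j over the edges arriving at node i\n",
    "z = np.zeros_like(x)\n",
    "np.add.at(z, dst, w[:, None] * x[src])\n",
    "\n",
    "# x'_i = ReLU(phi([x_i, z_i])) with a made-up affine map phi\n",
    "phi_w = np.ones((4, 2))\n",
    "phi_b = np.array([0.0, -5.0])\n",
    "x_new = np.maximum(0.0, np.concatenate([x, z], axis=1) @ phi_w + phi_b)\n",
    "\n",
    "print(z)\n",
    "print(x_new)\n",
    "```\n",
    "\n",
    "Node 2 has no incoming edges, so its aggregate stays at zero and its update depends only on its own features."
   ]
  },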
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GNN\n",
    "The GNN consists of several layers as defined above. \n",
    "A final embedding layer takes the node features output by the GNN layers and applies an affine transformation into a low-dimensional space whose size is set by `emb_dim`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GNN(nn.Module):\n",
    "    def __init__(self, nb_hidden_gnn, nb_layer, nb_hidden_kernel, input_dim, emb_dim=2):\n",
    "        super(GNN, self).__init__()\n",
    "\n",
    "        # Construct GNN Layers\n",
    "        gnn_layers = [GNN_Layer(input_dim, nb_hidden_gnn, nb_hidden_kernel, apply_norm=True)]\n",
    "        for _ in range(nb_layer-1):\n",
    "            gnn_layers.append(GNN_Layer(nb_hidden_gnn, nb_hidden_gnn, nb_hidden_kernel))\n",
    "        self.layers = nn.ModuleList(gnn_layers)\n",
    "\n",
    "        self.final_emb = nn.Linear(nb_hidden_gnn, emb_dim)\n",
    "\n",
    "    def forward(self, g):\n",
    "        emb = g.ndata.pop('feat')\n",
    "        for i, layer in enumerate(self.layers):\n",
    "            emb = layer(g, emb)\n",
    "        emb = self.final_emb(emb)\n",
    "        return emb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training\n",
    "\n",
    "## Dataset, Dataloader\n",
    "\n",
    "The TrackML_Dataset class is a PyTorch Dataset subclass, for use in a DataLoader class.\n",
    "The trackml_collate function should be used when instantiating the DataLoader class for minibatch training.\n",
    "\n",
    "Each sample will contain a graph (with features) used as input, and a graph with ground truth information."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gnn_utils import TrackML_Dataset\n",
    "\n",
    "def trackml_collate(samples):\n",
    "    g_input = [s[0] for s in samples]\n",
    "    g_input = dgl.batch(g_input)\n",
    "\n",
    "    g_true = [s[1] for s in samples]\n",
    "    g_true = dgl.batch(g_true)\n",
    "\n",
    "    return g_input, g_true"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This collate function makes use of DGL's batch functionality, allowing computation over minibatches of variable-sized graphs."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# PARAMETERS\n",
    "batch_size = 4\n",
    "nb_hidden = 32\n",
    "nb_layers = 4\n",
    "learn_rate = 0.001\n",
    "\n",
    "dataset = TrackML_Dataset(samples)\n",
    "dataloader = DataLoader(dataset, \n",
    "                        batch_size=batch_size, \n",
    "                        collate_fn=trackml_collate,\n",
    "                        drop_last=True, \n",
    "                        shuffle=True,\n",
    "                        num_workers=0)\n",
    "\n",
    "net = GNN(nb_hidden, nb_layers, nb_hidden, 6)\n",
    "optim = torch.optim.Adamax(net.parameters(), lr=learn_rate)\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loss\n",
    "Hinge embedding loss produces embeddings that are amenable to clustering, since it penalizes false pairs only when the distance between them becomes small.\n",
    "The hinge loss is stated as:\n",
    "\n",
    "$$\n",
    "l = \\begin{cases}\n",
    "    x, & \\text{if}\\; y = 1,\\\\\n",
    "    \\max \\{0, \\Delta - x\\}, & \\text{if}\\; y = -1,\n",
    "\\end{cases}\n",
    "$$\n",
    "\n",
    "where $x$ is the predicted distance between the two embedded hits of a pair, $y \\in \\{-1,1\\}$ is the target ($1$ for a true pair, $-1$ for a false pair), and $\\Delta$ is the margin (PyTorch's default of 1 is used here).\n",
    "\n",
    "Here a DGL edge function is used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_emb_for_loss(edges):\n",
    "    # Distance between the embeddings of each edge's endpoints\n",
    "    src = edges.src['emb']\n",
    "    dst = edges.dst['emb']\n",
    "    pred_dist = nn.functional.pairwise_distance(src, dst)\n",
    "    # Map truth labels {0, 1} to hinge targets {-1, 1}\n",
    "    truth = edges.data['truth']\n",
    "    true_dist = truth*2 - 1\n",
    "    # Per-edge loss; the training loop takes the mean\n",
    "    loss = nn.functional.hinge_embedding_loss(pred_dist, true_dist, reduction='none')\n",
    "    return {'loss' : loss, 'pred_dist' : pred_dist, 'true_dist' : true_dist}"
   ]
  },
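  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick worked example of the piecewise definition, the sketch below evaluates the hinge loss in plain Python for four made-up pairs, using a margin of 1. The function name `hinge_embedding` and the distances are hypothetical; the training code itself relies on PyTorch's `hinge_embedding_loss`.\n",
    "\n",
    "```python\n",
    "def hinge_embedding(dist, target, margin=1.0):\n",
    "    # target = 1 (true pair): the loss is the distance itself\n",
    "    if target == 1:\n",
    "        return dist\n",
    "    # target = -1 (false pair): penalized only when closer than the margin\n",
    "    return max(0.0, margin - dist)\n",
    "\n",
    "# (distance, truth label in {0, 1}) for four illustrative pairs\n",
    "pairs = [(0.2, 1), (1.5, 1), (0.2, 0), (1.5, 0)]\n",
    "\n",
    "# Same label mapping as in the edge function: truth*2 - 1 gives targets in {-1, 1}\n",
    "losses = [hinge_embedding(d, t * 2 - 1) for d, t in pairs]\n",
    "print(losses)\n",
    "```\n",
    "\n",
    "True pairs are pulled together in proportion to their distance, while a false pair stops contributing once it is at least one margin apart."
   ]
  },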
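  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Accuracy\n",
    "An accuracy proxy serves as a sanity check during training: a pair is predicted to belong to the same particle when its embedded distance rounds to 0 (that is, falls below 0.5), and this prediction is compared with the ground truth label."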
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def score_dist_accuracy(pred, true):\n",
    "    # Distances that round to 0 become label 1 (true pair); all others become label 0\n",
    "    pred = pred.round()\n",
    "    pred[pred!=0] = 1\n",
    "    pred = 1-pred\n",
    "    # Fraction of edges whose predicted label matches the ground truth\n",
    "    correct = pred==true\n",
    "    nb_correct = correct.sum()\n",
    "    nb_total = true.size(0)\n",
    "    score = float(nb_correct.item()) / nb_total\n",
    "    return score"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training\n",
    "Train over the dataset for a few epochs. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_one_epoch(net, batch_size, optimizer, train_loader):\n",
    "    net.train()\n",
    "\n",
    "    nb_batch = len(train_loader)\n",
    "    nb_train = nb_batch * batch_size\n",
    "    epoch_score = 0\n",
    "    epoch_loss  = 0\n",
    "\n",
    "    print(\"\\nTraining on {} samples\".format(nb_train))\n",
    "    for i, (g_input, g_true) in enumerate(train_loader):\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        # forward the GNN model on the input graph and set the result to the augmented graph\n",
    "        g_true.ndata['emb'] = net(g_input)\n",
    "        \n",
    "        # compute loss function over the augmented graph\n",
    "        g_true.apply_edges(get_emb_for_loss)\n",
    "        loss = g_true.edata.pop('loss').mean()\n",
    "        score = score_dist_accuracy(g_true.edata.pop('pred_dist'), g_true.edata.pop('truth'))\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        epoch_score += score * 100\n",
    "        epoch_loss  += loss.item()\n",
    "\n",
    "        nb_proc = (i+1) * batch_size\n",
    "        # report twice per epoch; max() avoids a modulo by zero when there is a single batch\n",
    "        if (i+1) % max(1, nb_batch//2) == 0:\n",
    "            print(\"  {:2d}  Loss: {:.3f}  Acc: {:2.1f}\".format(nb_proc, epoch_loss/(i+1), epoch_score/(i+1)))\n",
    "    return epoch_loss / nb_batch, epoch_score / nb_batch\n",
    "\n",
    "for i in range(20):\n",
    "    train_one_epoch(net, batch_size, optim, dataloader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Clustering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataloader = DataLoader(dataset, batch_size=1, collate_fn=trackml_collate)\n",
    "\n",
    "# Embed samples\n",
    "orig_xyz = []\n",
    "emb_metric = []\n",
    "emb_gnn = []\n",
    "pid = []\n",
    "weight = []\n",
    "net.eval()\n",
    "with torch.autograd.no_grad():\n",
    "    for i, (g_input, g_true) in enumerate(dataloader):\n",
    "        f = g_input.ndata['feat']\n",
    "        pid.append(g_input.ndata['pid'])\n",
    "        orig_xyz.append(g_input.ndata['feat'][:,:3])\n",
    "        emb_metric.append(g_input.ndata['feat'][:,3:])\n",
    "        weight.append(g_input.ndata['weight'])\n",
    "        hits_emb = net(g_input)\n",
    "        emb_gnn.append(hits_emb.numpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check, visualize the tracks in the original space, the metric learning embedding, and the GNN embedding.\n",
    "Note that the metric learning model was trained for many hours on a GPU, whereas the GNN in this notebook is trained only briefly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "j = 1\n",
    "to_plot = [orig_xyz[j], emb_metric[j], emb_gnn[j]]\n",
    "for emb in to_plot:\n",
    "    plot_clusters(emb[:,0], emb[:,1], pid[j])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### DBSCAN\n",
    "Cluster the embeddings output by the GNN for final evaluation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cluster\n",
    "from sklearn.cluster import DBSCAN\n",
    "\n",
    "c = DBSCAN(eps=.24, min_samples=3)\n",
    "\n",
    "def get_clusters(embedding):\n",
    "    return c.fit_predict(embedding)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Scoring\n",
    "The score of one predicted cluster is nonzero only if a large majority of its points belong to the same true cluster, and if the majority of the true cluster is contained within the predicted cluster.  A perfect clustering will thus lead to a score of 1, while a random clustering will almost certainly have a score of 0."
   ]
  },
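  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The scoring rule described above can be illustrated with a simplified sketch in plain Python. This is not the `score_event` function imported from `gnn_utils`, whose details may differ: the function name `double_majority_score` and the use of 0.5 for both majority thresholds are assumptions made for the example.\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "\n",
    "def double_majority_score(track_ids, particle_ids, weights):\n",
    "    # Hypothetical simplified score; both 0.5 thresholds are assumptions\n",
    "    track_sizes = Counter(track_ids)\n",
    "    particle_sizes = Counter(particle_ids)\n",
    "    overlap = Counter(zip(track_ids, particle_ids))\n",
    "\n",
    "    # For each predicted track, find the particle contributing the most hits\n",
    "    best = {}\n",
    "    for (t, p), n in overlap.items():\n",
    "        if t not in best or n > overlap[(t, best[t])]:\n",
    "            best[t] = p\n",
    "\n",
    "    score = 0.0\n",
    "    for t, p, w in zip(track_ids, particle_ids, weights):\n",
    "        n = overlap[(t, best[t])]\n",
    "        # The track must be mostly one particle and hold most of that particle's hits\n",
    "        good = n > 0.5 * track_sizes[t] and n > 0.5 * particle_sizes[best[t]]\n",
    "        if good and p == best[t]:\n",
    "            score += w\n",
    "    return score / sum(weights)\n",
    "\n",
    "# One particle with four hits; the clustering splits off its last hit\n",
    "toy_score = double_majority_score([0, 0, 0, 1], [7, 7, 7, 7], [1.0, 1.0, 1.0, 1.0])\n",
    "print(toy_score)\n",
    "```\n",
    "\n",
    "Under these assumptions, the three hits kept together count toward the score, while the split-off hit does not because its track holds only a minority of the particle's hits."
   ]
  },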
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score samples\n",
    "from gnn_utils import score_event"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Get final cluster scores\n",
    "import pandas\n",
    "\n",
    "avg_score = 0.0\n",
    "nb_samples = 20\n",
    "for i in range(nb_samples):\n",
    "    # emb = samples[i]['hits']['emb']\n",
    "    # clusters = get_clusters(emb)\n",
    "    clusters = get_clusters(emb_gnn[i])\n",
    "    hit_ids = np.arange(len(clusters))\n",
    "    truth = pandas.DataFrame.from_dict({'particle_id':pid[i].numpy(),\n",
    "                                        'hit_id':hit_ids,\n",
    "                                        'weight':weight[i].numpy()})\n",
    "    submission = pandas.DataFrame.from_dict({'hit_id':hit_ids,\n",
    "                                             'track_id':clusters})\n",
    "    score = score_event(truth, submission)\n",
    "    avg_score += score\n",
    "print(\"TrackML score: {:.2f}\".format(avg_score / nb_samples))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
